联邦学习框架Flower 您所在的位置:网站首页 sunshine flower什么意思 联邦学习框架Flower

联邦学习框架Flower

2024-06-03 09:04| 来源: 网络整理| 查看: 265

在这里插入图片描述

在本教程中,我们将学习如何使用 Flower 和 PyTorch 在 CIFAR10 上训练卷积神经网络

首先,建议创建一个虚拟环境并在其中运行所有内容。

我们的示例由一台服务器和两个具有相同模型的客户端组成。

客户端负责根据其本地数据集为模型生成单独的权重更新。然后将这些更新发送到服务器,服务器将聚合它们以生成更好的模型。最后,服务器将此改进版本的模型发送回每个客户端。权重更新的完整周期称为回合。

现在我们对正在发生的事情有了大致的了解,让我们开始吧。我们首先需要安装 Flower。您可以通过运行以下命令来执行此操作:

$ pip install flwr

既然我们想使用 PyTorch 来解决计算机视觉任务,让我们继续安装 PyTorch 和 torchvision 库:

$ pip install torch torchvision Flower客户端

现在我们已经安装了所有依赖项,让我们使用两个客户端和一个服务器运行一个简单的分布式训练。我们的训练进程和网络架构基于 PyTorch 的深度学习和 PyTorch。 在名为 client.py 的文档中,导入 Flower 和 PyTorch 相关包:

from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 import flwr as fl

此外,我们在 PyTorch 中定义设备分配:

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

我们使用 PyTorch 加载 CIFAR10,这是一个流行的用于机器学习的彩色图像分类数据集。PyTorch DataLoader() 下载训练和测试数据,然后对其进行规范化。

def load_data(): """Load CIFAR-10 (training and test set).""" transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) trainset = CIFAR10(".", train=True, download=True, transform=transform) testset = CIFAR10(".", train=False, download=True, transform=transform) trainloader = DataLoader(trainset, batch_size=32, shuffle=True) testloader = DataLoader(testset, batch_size=32) num_examples = {"trainset" : len(trainset), "testset" : len(testset)} return trainloader, testloader, num_examples

使用 PyTorch 定义损失和优化器。数据集的训练是通过循环访问数据集来完成的,测量相应的损失并对其进行优化。

def train(net, trainloader, epochs): """Train the network on the training set.""" criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for _ in range(epochs): for images, labels in trainloader: images, labels = images.to(DEVICE), labels.to(DEVICE) optimizer.zero_grad() loss = criterion(net(images), labels) loss.backward() optimizer.step()

然后定义机器学习网络的验证。我们遍历测试集并测量测试集的损失和准确性。

def test(net, testloader): """Validate the network on the entire test set.""" criterion = torch.nn.CrossEntropyLoss() correct, total, loss = 0, 0, 0.0 with torch.no_grad(): for data in testloader: images, labels = data[0].to(DEVICE), data[1].to(DEVICE) outputs = net(images) loss += criterion(outputs, labels).item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total return loss, accuracy

在定义了 PyTorch 机器学习模型的训练和测试后,我们将函数用于 Flower 客户端。

Flower客户将使用改编自“PyTorch:60分钟闪电战”的简单CNN:

class Net(nn.Module): def __init__(self) -> None: super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x #Load model and data net = Net().to(DEVICE) trainloader, testloader, num_examples = load_data()

使用 load_data() 加载数据集后,我们定义 Flower 接口。

Flower 服务器通过名为 Client 的接口与客户端交互。当服务器选择特定客户端进行训练时,它会通过网络发送训练指令。客户端接收这些指令并调用 Client 方法之一来运行代码(即,训练我们之前定义的神经网络)。

Flower提供了一个名为NumPyClient的方便类,当您的工作负载使用PyTorch时,它可以更轻松地实现 Client 接口。实现 NumPyClient 通常意味着定义以下方法(set_parameters是可选的):

get_parameters 将模型权重作为 NumPy ndarray 的列表返回set_parameters (optional) 使用从服务器接收的参数更新本地模型权重fit 设置本地模型权重 训练本地模型 接收更新的本地模型权重evaluate 测试本地模型 可以通过以下方式实现: class CifarClient(fl.client.NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] def set_parameters(self, parameters): params_dict = zip(net.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) net.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): self.set_parameters(parameters) train(net, trainloader, epochs=1) return self.get_parameters(config={}), num_examples["trainset"], {} def evaluate(self, parameters, config): self.set_parameters(parameters) loss, accuracy = test(net, testloader) return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}

现在,我们可以创建类 CifarClient 的实例,并添加一行来实际运行此客户端:

fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())

这就是对客户来说。我们只需要实现客户端或 NumPyClient 并调用 fl.client.start_client() 或 fl.client.start_numpy_client()。字符串“[::]:8080”告诉客户端要连接到哪个服务器。在我们的例子中,我们可以在同一台机器上运行服务器和客户端,因此我们使用“[::]:8080”。如果我们运行一个真正的联合工作负载,服务器和客户端在不同的机器上运行,那幺需要更改的只是我们指向客户端server_address。

Flower服务器端

对于简单的工作负载,我们可以启动 Flower 服务器并将所有配置可能性保留为其默认值。在名为 server.py 的文档中,导入 Flower 并启动服务器:

import flwr as fl fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3)) 训练模型

客户端和服务器都准备好了,我们现在可以运行所有内容并查看联邦学习的实际效果。 FL 系统通常有一个服务器和多个客户端。因此,我们必须首先启动服务器:

$ python server.py

服务器运行后,我们可以在不同的终端启动客户端。打开一个新终端并启动第一个客户端:

$ python client.py

打开另一个终端,启动第二个客户端:

$ python client.py

每个客户端都有自己的数据集。现在,您应该看到训练在第一个终端(启动服务器的终端)中是如何完成的:

INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds) INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2) DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2) DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2) DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046 INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)] INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed [] INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized [] INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized [] DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625 INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))] INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []

祝贺!您已成功构建并运行了您的第一个联邦学习系统。此示例的完整源代码可以在 examples/quickstart_pytorch 中找到。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有